# First and last entries of `dates` (the fused cell output is kept as a comment).
dates[0], dates[-1]  # (cell output) ('2021-04-05', '2023-10-04')
[Array(0.88235294, dtype=float64),
Array(0.42857143, dtype=float64),
Array(0.20802005, dtype=float64),
Array(0.31587057, dtype=float64),
Array(0.74633431, dtype=float64),
Array(1.09724473, dtype=float64),
Array(0.63493344, dtype=float64)]
# Fitting window: rows i_start .. i_end - 1 of `h`, first `n_delay` columns.
np1 = 158  # number of rows in the fitting window
i_start = 260  # the original comment also listed 231, presumably an earlier start
i_end = i_start + np1
n_delay = 6
y = h[i_start:i_end, :n_delay]
# Initial guess for p0: summed counts of the first 7 rows of `y` divided by the
# summed `I` over the same 7 rows.
p0_hat = y[:7].sum() / I[i_start : i_start + 7].sum()
n_weekday = 2
aux = (np1, n_delay, n_weekday, I[i_start:i_end])
# Hand-picked parameters; theta holds their logarithms.
# NOTE(review): the original one-line comment listed the labels in the order
# s2_p, s2_q, s2_W, p0, s2_0, which puts "p0" on the 4th entry although
# `p0_hat` is the 5th. The labels below follow the values -- confirm the
# parameter order against `hospitalization_model`.
theta_manual = jnp.log(
    jnp.array(
        [
            1**2,  # s2_p
            1**2,  # s2_q
            0.1**2,  # s2_W
            0.1**2,  # s2_0 (presumably)
            p0_hat,  # p0
        ]
    )
)
model_manual = hospitalization_model(theta_manual, aux)
# (cell output) message: Optimization terminated successfully.
success: True
status: 0
fun: 5.256827861654503
x: [-8.761e+00 -4.908e+00 -1.072e+01 -3.119e+00 -3.532e+00]
nit: 36
jac: [-1.549e-06 6.966e-07 4.007e-07 2.218e-06 5.756e-06]
hess_inv: [[ 5.288e+01 -7.922e-03 ... -1.135e+00 2.247e+00]
[-7.922e-03 5.991e+00 ... 2.864e-01 -1.159e-01]
...
[-1.135e+00 2.864e-01 ... 1.022e+02 -5.347e+00]
[ 2.247e+00 -1.159e-01 ... -5.347e+00 3.627e+01]]
nfev: 440
njev: 40

from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
modified_efficient_importance_sampling as MEIS,
)
from isssm.importance_sampling import pgssm_importance_sampling, ess_pct
import jax.random as jrn
# Laplace approximation (LA) proposal for `y` under `model0`.
# NOTE(review): `model0` is defined in a cell not shown here (it is not
# `model_manual` from above) -- confirm which model is intended. The bare
# integer arguments below are presumably iteration / sample counts; confirm
# against the isssm function signatures.
proposal_la, info_la = LA(y, model0, 100)
key = jrn.PRNGKey(423423423)
key, subkey = jrn.split(key)
# Refine with modified efficient importance sampling, started from the LA
# proposal's `z` and `Omega`.
proposal_meis, info_meis = MEIS(
    y, model0, proposal_la.z, proposal_la.Omega, 100, 10000, subkey
)
key, subkey = jrn.split(key)
# Importance samples and log-weights from the MEIS proposal, using a fresh subkey.
samples, log_weights = pgssm_importance_sampling(
    y, model0, proposal_meis.z, proposal_meis.Omega, 1000, subkey
)
# ess_pct: presumably the effective sample size as a percentage of the draws.
ess_pct(log_weights)  # (cell output) Array(85.38643487, dtype=float64)
Deviation from the standard setup: the missing indices of \(B\) now differ from those of \(Y\).
Test: make every observation in the last week missing, except the initial observation.
y_nan = make_y_nan(y)
# Use a shorter window for the missing-data experiment: keep the last
# `np1_miss` rows.
np1_miss = 100
y_nan = y_nan[-np1_miss:]
# Re-read the length from the data: if `y_nan` has fewer than `np1_miss` rows,
# the slice above returns all of them.
np1_miss, _ = y_nan.shape
missing_y_indices = jnp.isnan(y_nan)
# Missingness mask for S: the Y mask shifted one column to the right along the
# last axis, with the first column never missing.
missing_s_indicies = jnp.concatenate(
    (jnp.full((np1_miss, 1), False, dtype=bool), missing_y_indices[:, :-1]), axis=-1
)
# `I` restricted to the same (last `np1_miss`) rows of the fitting window as
# `y_nan` -- assumes make_y_nan keeps the rows of `y`; TODO confirm.
aux_miss = (np1_miss, n_delay, n_weekday, I[i_start + (np1 - np1_miss) : i_start + np1])
_, y_miss = account_for_nans(
    hospitalization_model(theta0, aux_miss),
    y_nan,
    missing_y_indices,
    missing_s_indicies,
)


def _model_miss(theta, aux):
    """Return the hospitalization model adjusted for the missingness pattern.

    Builds ``hospitalization_model(theta, aux)`` and passes it through
    ``account_for_nans`` with the notebook-level ``y_nan`` and the two masks
    (looked up at call time). Only the first element of that result, the
    adjusted model, is returned.
    """
    return account_for_nans(
        hospitalization_model(theta, aux), y_nan, missing_y_indices, missing_s_indicies
    )[0]
# (cell output) message: Desired error not necessarily achieved due to precision loss.
success: False
status: 2
fun: 4.61938723602526
x: [-8.761e+00 -4.613e+00 -1.072e+01 -3.055e+00 -3.662e+00]
nit: 2
jac: [ 3.004e-04 -7.560e-02 -7.003e-04 -1.635e-02 3.963e-02]
hess_inv: [[ 1.000e+00 -1.310e-02 ... -2.867e-03 5.721e-03]
[-1.310e-02 8.565e+00 ... 1.681e+00 -3.772e+00]
...
[-2.867e-03 1.681e+00 ... 1.371e+00 -8.318e-01]
[ 5.721e-03 -3.772e+00 ... -8.318e-01 2.853e+00]]
nfev: 582
njev: 52

message: Optimization terminated successfully.
success: True
status: 0
fun: 5.641690135301365
x: [-7.624e+00 -4.882e+00 -2.155e+01 -3.572e+00 -4.286e+00]
nit: 43
jac: [ 1.113e-06 -7.624e-06 2.420e-08 -3.275e-06 -4.825e-06]
hess_inv: [[ 4.046e+01 7.956e-01 ... 2.149e+00 1.062e+00]
[ 7.956e-01 3.755e+00 ... -2.142e+00 -2.584e-01]
...
[ 2.149e+00 -2.142e+00 ... 3.599e+01 -3.441e+00]
[ 1.062e+00 -2.584e-01 ... -3.441e+00 8.980e+00]]
nfev: 495
njev: 45
model_miss0 = _model_miss(theta0_missing, aux_miss)
# Laplace approximation for the missing-data model. This rebinds
# `proposal_la` / `info_la` from the full-data cell. 10000 is presumably the
# iteration cap and `eps` the convergence tolerance -- confirm against LA_missing.
proposal_la, info_la = LA_missing(y_miss, model_miss0, 10000, eps=1e-10)
# Eigenvalues of the proposal's Omega, computed once and reused for the title
# and the heat map.
omega_eigvals = jnp.linalg.eigvalsh(proposal_la.Omega)
plt.figure(figsize=(20, 8))
plt.title(
    f"Min. eigenvalue: {omega_eigvals.min():.2f}, "
    f"converged in {info_la.n_iter} iterations"
)
# Transposed so the eigenvalue index runs along the rows -- assumes Omega is a
# stack of matrices along its leading axis; TODO confirm.
plt.imshow(omega_eigvals.T)
plt.colorbar()
# NOTE(review): the warning below is raised by
# `jnp.linalg.pinv(Gamma, hermitian=True, rcond=1e-5)` in code defined outside
# this cell; presumably `rtol=1e-5` is the drop-in replacement -- confirm in
# the JAX docs.
# (cell output) /var/folders/9y/xdxkkt710kx5tf1j0p68y46r0000gn/T/ipykernel_52604/869667675.py:70: DeprecationWarning: The rcond argument for linalg.pinv is deprecated. Please use rtol instead.
Omega = jnp.linalg.pinv(Gamma, hermitian=True, rcond=1e-5)

[3325.93455402 1295.05228515 358.03265507 0. 0.
0. ]
[3326. 1295. 358. 0. 0. 0.]
-6.177065316487707

Array(0.0823388, dtype=float64)

# Posterior summaries from the weighted draws (presumably self-normalised
# importance-sampling averages -- confirm against mc_integration).
# NOTE(review): `samples` / `log_weights` must come from the missing-data model.
# The draws from `pgssm_importance_sampling(y, model0, ...)` further up were for
# `model0` on the full window -- confirm they were regenerated in a cell not
# shown here.
post = mc_integration(samples, log_weights)
# state_mode(model_miss0, sample) evaluated per sample: vmap maps over the
# leading axis of `samples` and holds `model_miss0` fixed.
post_state = mc_integration(
    vmap(state_mode, (None, 0))(model_miss0, samples), log_weights
)
fig, axs = plt.subplots(3, 2, figsize=(10, 10))
axs = axs.flatten()
# Panel 0: exp(post[:, 0]) scaled by `I` over the missing-data window, against
# the row totals of `y_miss` and of the full `y` over the same rows.
# NOTE(review): this `I` slice equals the one stored in `aux_miss[3]`.
axs[0].plot(
    jnp.exp(post[:, 0]) * I[i_start + (np1 - np1_miss) : i_start + np1],
    label="predicted",
)
axs[0].plot(y_miss[-(np1_miss):].sum(axis=-1), label="truth missing")
axs[0].plot(y[-(np1_miss):].sum(axis=-1), label="truth")
axs[0].legend()
# Panel 1: exp of the first column of `post` on its own.
axs[1].plot(jnp.exp(post[:, 0]))
# Panel 2: from_consecutive_logits of the remaining columns of `post` (dashed)
# and of state columns 1-7 (solid).
axs[2].plot(from_consecutive_logits(post[:, 1:]), linestyle="--")
axs[2].plot(from_consecutive_logits(post_state[:, 1:8]))
# Panel 3: state columns 1-7 before the transform.
axs[3].plot(post_state[:, 1:8])
# Panel 4: state columns 8 and 8 + 6 (= 14); the offset 6 presumably matches
# n_delay -- confirm. axs[5] is left empty.
axs[4].plot(post_state[:, 8])
axs[4].plot(post_state[:, 8 + 6])
plt.show()
from isssm.importance_sampling import prediction
def f_nowcast(x, s, y, *, missing=None, observed=None):
    """Return per-row totals where missing cells are filled from ``y``.

    Cells flagged in ``missing`` take their value from ``y``; all other cells
    keep the value from ``observed``. The result is summed over the last axis.

    Args:
        x, s: Unused here. Presumably part of the callback signature expected
            by ``isssm.importance_sampling.prediction`` -- confirm.
        y: Values used for the missing cells; must broadcast against
            ``missing`` and ``observed``.
        missing: Boolean mask of missing cells. Defaults to the notebook-level
            ``missing_y_indices``, looked up at call time.
        observed: Values used for the non-missing cells. Defaults to the
            notebook-level ``y_miss``, looked up at call time.

    Returns:
        Array with the last axis summed out.
    """
    if missing is None:
        missing = missing_y_indices
    if observed is None:
        observed = y_miss
    # Select rather than blend with 0/1 weights: with `m * y + (1 - m) * obs`, a
    # NaN or inf in the unselected array still reaches the sum (0 * nan == nan).
    return jnp.sum(jnp.where(missing, y, observed), axis=-1)
key, subkey = jrn.split(key)
# Monte Carlo prediction of `f_nowcast` under the LA proposal for the
# missing-data model.
# NOTE(review): the meanings of these positional arguments are inferred from
# the call site -- confirm against isssm.importance_sampling.prediction:
#   y_miss, proposal_la, the missing-data model, 10000 (presumably the number
#   of draws), subkey, quantile levels 2.5% / 50% / 97.5% (presumably), and the
#   unadjusted hospitalization model.
# `_model_miss(theta0_missing, aux_miss)` is the same expression used to build
# `model_miss0` earlier.
preds = prediction(
    f_nowcast,
    y_miss,
    proposal_la,
    _model_miss(theta0_missing, aux_miss),
    10000,
    subkey,
    jnp.array([0.025, 0.5, 0.975]),
    hospitalization_model(theta0_missing, aux_miss),
)